import os
import argparse
import torch
import time
import random
import wandb
import colossalai
import torch.distributed as dist
import numpy as np
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, DataCollatorForLanguageModeling
from datasets import Dataset
from contextlib import nullcontext

def parse_arguments():
    """Parse the command-line options of the training script.

    Options are registered from a table so that each flag, its default and
    its help text sit on a single row.  After parsing, a ``device``
    attribute holding ``torch.device('cuda')`` is attached to the namespace.

    Returns:
        argparse.Namespace: the parsed options plus ``device``.
    """
    option_table = (
        # run mode / parallel layout / data / model selection
        ("--mode", dict(default='train', type=str)),
        ("--tp_size", dict(default=2, type=int)),
        ("--dataset", dict(default='openwebtext', type=str, help='wikitext or openwebtext')),
        ("--model", dict(default='GPT2', type=str, help='GPT2, BERT, Llama2')),
        ("--model_name", dict(default='GPT2-L', type=str, help='GPT2-L,XL')),
        # optimisation schedule
        ("--epoch", dict(default=1, type=int)),
        ("--gradient_accumulation", dict(action='store_true')),
        ("--gradient_accumulation_value", dict(default=5, type=int)),
        ("--gradient_clipping", dict(action='store_true')),
        ("--batch", dict(default=16, type=int)),
        ("--max_seqlength", dict(default=1024, type=int)),
        # hyper-parameters
        ("--lr", dict(default=1e-4, type=float)),
        ("--weight_decay", dict(default=0.01, type=float)),
        ("--hidden_dropout", dict(default=0.0, type=float)),
        ("--attn_dropout", dict(default=0.0, type=float)),
        ("--eps", dict(default=1e-6, type=float)),
        # reproducibility / data loading / logging / checkpoints
        ("--seed", dict(default=42, type=int)),
        ("--num_workers", dict(default=8, type=int)),
        ("--wandb", dict(action='store_true')),
        ("--project_name", dict(default='GPT2-M-openwebtext', type=str)),
        ("--start_epoch", dict(default=0, type=int, metavar="N", help="start epoch")),
        ("--output_dir", dict(default='./checkpoints_test', type=str, help="Directory to save checkpoints")),
    )

    parser = argparse.ArgumentParser()
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)

    parsed = parser.parse_args()
    parsed.device = torch.device('cuda')
    return parsed

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: value handed unchanged to ``random.seed``, ``np.random.seed``
            and ``torch.manual_seed``.
    """
    # The three generators are independent, so the seeding order is irrelevant.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

def save_dataset(dataset, path):
    """Persist ``dataset`` by calling its ``save_to_disk(path)`` method.

    Args:
        dataset: object exposing ``save_to_disk``; presumably a HuggingFace
            ``datasets.Dataset`` -- confirm against callers.
        path: destination forwarded unchanged to ``save_to_disk``.
    """
    dataset.save_to_disk(path)

def load_dataset_from_disk(path):
    """Return the dataset stored at ``path`` via ``Dataset.load_from_disk``.

    Args:
        path: location forwarded unchanged to ``Dataset.load_from_disk``.
    """
    return Dataset.load_from_disk(path)

def filter_rows(example):
    """Tell whether a dataset row has a non-empty ``'text'`` field.

    Args:
        example: mapping with a ``'text'`` key.

    Returns:
        bool: ``True`` unless ``example['text']`` equals the empty string.
    """
    text = example['text']
    return text != ''

def load_dataset(args):
    """Load the pre-tokenized train/validation splits from ``./data/huggingface``.

    The directory names are derived from ``args.dataset`` and
    ``args.max_seqlength``.  Both directories must already exist; this
    function never tokenizes anything itself.

    Args:
        args: namespace providing ``dataset`` and ``max_seqlength``.

    Returns:
        tuple: ``(train_dataset, eval_dataset, number_of_training_rows)``.

    Raises:
        NotImplementedError: if either split directory is missing.
    """
    data_root = "./data/huggingface"
    split_paths = {
        split: f"{data_root}/{args.dataset}_{split}_seq:{args.max_seqlength}"
        for split in ("train", "val")
    }

    # Fail early when either split has not been prepared on disk.
    if not (os.path.exists(split_paths["train"]) and os.path.exists(split_paths["val"])):
        raise NotImplementedError("the dataset path doesn't exist")

    train_split = load_dataset_from_disk(split_paths["train"])
    eval_split = load_dataset_from_disk(split_paths["val"])
    return train_split, eval_split, len(train_split)

def create_dataloaders(args, train_dataset, eval_dataset, data_collator, plugin):
    """Build the training and validation dataloaders through ``plugin``.

    The per-loader batch size is ``args.batch`` integer-divided by
    ``args.world_size // plugin.tp_size``.  Both loaders are created with the
    same options (shuffling on, ``drop_last`` off, the same seed, worker
    count and collate function).

    Args:
        args: namespace providing ``batch``, ``world_size``, ``seed`` and
            ``num_workers``.
        train_dataset: dataset for the training loader.
        eval_dataset: dataset for the validation loader.
        data_collator: callable passed as ``collate_fn``.
        plugin: object exposing ``tp_size`` and ``prepare_dataloader``.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    replica_count = args.world_size // plugin.tp_size
    loader_options = dict(
        batch_size=args.batch // replica_count,
        shuffle=True,
        seed=args.seed,
        num_workers=args.num_workers,
        drop_last=False,
        collate_fn=data_collator,
    )
    train_loader, val_loader = (
        plugin.prepare_dataloader(split, **loader_options)
        for split in (train_dataset, eval_dataset)
    )
    return train_loader, val_loader

def create_model(args):
    """Instantiate a randomly initialised GPT-2 language model from ``args``.

    Side effect: ``args`` gains ``activation`` and, for a known
    ``model_name``, ``num_layer``, ``num_head`` and ``hidden_dim``.

    Args:
        args: namespace providing ``model``, ``model_name``,
            ``max_seqlength``, ``hidden_dropout`` and ``attn_dropout``.

    Returns:
        GPT2LMHeadModel: model built from the selected size preset.

    Raises:
        NotImplementedError: if ``args.model`` is not ``'GPT2'``.
        ValueError: if ``args.model_name`` has no GPT-2 preset.
    """
    if args.model != 'GPT2':
        raise NotImplementedError(f"Model {args.model} not implemented")

    args.activation = 'gelu_new'
    size_presets = {
        'GPT2-L': {'num_layer': 36, 'num_head': 20, 'hidden_dim': 1280},
        # original: 48, 25, 1600
        'GPT2-XL': {'num_layer': 48, 'num_head': 24, 'hidden_dim': 1584},
    }

    preset = size_presets.get(args.model_name)
    if preset is None:
        raise ValueError(f"Unknown model_name {args.model_name} for GPT2")

    # Mirror the chosen preset onto args so later code can read it back.
    for field in ('num_layer', 'num_head', 'hidden_dim'):
        setattr(args, field, preset[field])

    configuration = GPT2Config(
        n_positions=args.max_seqlength,
        n_embd=args.hidden_dim,
        n_layer=args.num_layer,
        n_head=args.num_head,
        activation_function=args.activation,
        resid_pdrop=args.hidden_dropout,
        attn_pdrop=args.attn_dropout,
    )
    return GPT2LMHeadModel(configuration)

def create_optimizer(args, model):
    """Create an AdamW optimizer over every parameter of ``model``.

    Args:
        args: namespace providing ``lr``, ``weight_decay`` and ``eps``.
        model: module whose ``parameters()`` are optimised.

    Returns:
        torch.optim.AdamW: the configured optimizer.
    """
    hyperparameters = {
        'lr': args.lr,
        'weight_decay': args.weight_decay,
        'eps': args.eps,
    }
    return torch.optim.AdamW(model.parameters(), **hyperparameters)

def main():
    """Script entry point: set everything up, then train and checkpoint.

    In order: parse the CLI options, create the output directory, seed the
    RNGs, launch the ColossalAI distributed environment, load the GPT-2
    tokenizer and the pre-tokenized datasets, build the plugin, dataloaders,
    model and booster, and optionally start a wandb run.  When ``--mode`` is
    ``train`` the epoch loop then trains, validates and writes sharded model
    and optimizer checkpoints after every epoch; any other mode stops after
    the setup.

    Raises:
        NotImplementedError: if ``args.model`` is not ``'GPT2'``.
    """
    args = parse_arguments()

    # NOTE(review): this check runs before colossalai.launch_from_torch();
    # confirm whether the CUDA device index already differs per rank here.
    if torch.cuda.current_device() == 0:
        os.makedirs(args.output_dir, exist_ok=True)

    set_seed(args.seed)
    # Optional NCCL knobs, left disabled:
    #   os.environ['NCCL_IB_DISABLE'] = '1'
    #   os.environ['NCCL_P2P_LEVEL'] = 'NVL'
    colossalai.launch_from_torch()
    args.world_size = int(os.environ.get('WORLD_SIZE', 1))
    store_path = './data/huggingface'

    # --- tokenizer -----------------------------------------------------
    if args.model != 'GPT2':
        raise NotImplementedError(f"Model {args.model} not implemented")
    tokenizer = GPT2Tokenizer.from_pretrained('openai-community/gpt2', cache_dir=store_path)
    # The end-of-sequence token doubles as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    # --- data ----------------------------------------------------------
    train_tokenized, eval_tokenized, total_text_num = load_dataset(args)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # --- parallel plugin, loaders, model, booster ------------------------
    plugin = HybridParallelPlugin(
        tp_size=args.tp_size,
        pp_size=1,
        precision='bf16',
        max_norm=1.0,
        zero_stage=0,
    )
    train_loader, val_loader = create_dataloaders(
        args, train_tokenized, eval_tokenized, data_collator, plugin
    )
    model = create_model(args)
    booster = Booster(plugin=plugin)

    # --- experiment tracking ---------------------------------------------
    # NOTE(review): gated on the CUDA device index rather than the process
    # rank -- verify this is the intended behaviour for multi-node runs.
    if torch.cuda.current_device() == 0 and args.wandb:
        run_name = f"{args.model_name}_TP{plugin.tp_size}_batch:{args.batch}X{args.gradient_accumulation_value}_lr:{args.lr}_BF16_GA"
        wandb.init(project=args.project_name, entity="YOUR ENTITY", name=run_name)

    if args.mode != 'train':
        # Evaluation mode is not implemented; nothing more to do.
        return

    # --- training --------------------------------------------------------
    optimizer = create_optimizer(args, model)
    boosted = booster.boost(
        model=model,
        optimizer=optimizer,
        criterion=None,
        dataloader=train_loader,
        lr_scheduler=None,
    )
    # Only the wrapped model, optimizer and dataloader are kept.
    model, optimizer, _, train_loader, _ = boosted
    model = model.to(args.device)

    if torch.cuda.current_device() == 0:
        # Report the size of the model as seen through model.parameters().
        parameters = list(model.parameters())
        model_bytes = sum(p.numel() * p.element_size() for p in parameters)
        print(model)
        print(f"Model size: {model_bytes / 1e6:.2f} MB")
        print('Model Parameters:', sum(p.numel() for p in parameters))

    for epoch in range(args.start_epoch, args.epoch):
        epoch_started = time.time()
        train_one_epoch(model, train_loader, booster, optimizer, epoch, args.batch, total_text_num, args)
        elapsed = time.time() - epoch_started
        print(f'\nEpoch {epoch} completed in {elapsed:.2f} seconds')

        val_one_epoch(model, val_loader, epoch, args)

        # Every process calls the save routines (no rank gating here).
        model_checkpoint = os.path.join(args.output_dir, f"checkpoint_model_epoch{epoch}")
        optimizer_checkpoint = os.path.join(args.output_dir, f"checkpoint_optimizer_epoch{epoch}")
        booster.save_model(model, model_checkpoint, shard=True, gather_dtensor=False)
        booster.save_optimizer(optimizer, optimizer_checkpoint, shard=True, gather_dtensor=False)



def train_one_epoch(model, loader, booster, optimizer, epoch, batch_size, total_text_num, args):
    """Run one training pass over ``loader``.

    With ``args.gradient_accumulation`` set, gradients are accumulated over
    ``args.gradient_accumulation_value`` micro-batches.  Gradient
    synchronisation is suppressed through ``booster.no_sync(model)`` on every
    micro-batch except the last one of each window, so the accumulated
    gradients are synchronised right before ``optimizer.step()``.  The
    previous code did the opposite (sync on the first micro-batch only),
    which stepped on gradients that had never been reduced.

    Micro-batches left over at the end of the epoch (when ``len(loader)`` is
    not a multiple of the accumulation value) never trigger an optimizer
    step; their gradients are cleared by the ``zero_grad()`` at the start of
    the next call.

    Loss reporting happens on global rank 0, the destination of
    ``dist.reduce``.  When no process group is initialised the function runs
    as a single process and skips the collective.

    Args:
        model: callable returning an object with a ``loss`` attribute.
        loader: sized iterable of dicts with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        booster: object exposing ``backward(loss, optimizer)`` and
            ``no_sync(model)``.
        optimizer: object exposing ``step()``, ``zero_grad()`` and
            ``param_groups``.
        epoch: epoch index, used for logging only.
        batch_size: unused; kept for interface compatibility.
        total_text_num: unused; kept for interface compatibility.
        args: namespace providing ``device``, ``gradient_accumulation``,
            ``gradient_accumulation_value``, ``wandb`` and ``world_size``.
    """
    distributed = dist.is_available() and dist.is_initialized()
    # dist.reduce(dst=0) only delivers the sum to global rank 0, so that is
    # the process allowed to report.
    is_main_process = not distributed or dist.get_rank() == 0
    accumulation_steps = args.gradient_accumulation_value

    model.train()
    total_loss = 0
    optimizer.zero_grad()
    t1 = time.time()
    for batch_idx, batch in enumerate(loader, start=1):
        input_ids = batch['input_ids'].to(args.device)
        attention_mask = batch['attention_mask'].to(args.device)
        labels = batch['labels'].to(args.device)

        if args.gradient_accumulation:
            accumulation_step = (batch_idx - 1) % accumulation_steps
            is_update_step = accumulation_step == accumulation_steps - 1
            # Synchronise gradients only on the micro-batch that precedes the
            # optimizer step; earlier micro-batches accumulate locally.
            sync_context = nullcontext() if is_update_step else booster.no_sync(model)
            with sync_context:
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss
                scaled_loss = loss / accumulation_steps
                booster.backward(scaled_loss, optimizer)
                total_loss += loss.detach()  # Accumulate unscaled loss
            if is_update_step:
                optimizer.step()
                optimizer.zero_grad()
                # Average the loss over the window, then over all processes.
                total_loss = total_loss / accumulation_steps
                if distributed:
                    dist.reduce(total_loss, dst=0, op=dist.ReduceOp.SUM)
                if is_main_process:
                    total_loss = total_loss / args.world_size
                    ppl = torch.exp(total_loss)
                    progress = (batch_idx / len(loader)) * 100
                    print(f'\r[Epoch {epoch}] Progress: {progress:.3f}%   Train Loss: {total_loss:.3f}   PPL: {ppl:.1f}   Time: {time.time() - t1:.3f}s', end='')
                    if args.wandb:
                        wandb.log({
                            "epoch": epoch,
                            "train loss": total_loss.item(),
                            "train PPL": ppl.item(),
                            "lr": optimizer.param_groups[0]['lr']
                        })
                total_loss = 0
                t1 = time.time()
        else:
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            booster.backward(loss, optimizer)
            optimizer.step()
            optimizer.zero_grad()
            if is_main_process:
                # The loss shown here is this process's own loss; it is not
                # reduced across processes on this path.
                ppl = torch.exp(loss)
                progress = (batch_idx / len(loader)) * 100
                print(f'\r[Epoch {epoch}] Progress: {progress:.3f}%   Train Loss: {loss:.3f}   PPL: {ppl:.1f}   Time: {time.time() - t1:.3f}s', end='')
                if args.wandb:
                    wandb.log({
                        "epoch": epoch,
                        "train loss": loss.item(),
                        "train PPL": ppl.item(),
                        "lr": optimizer.param_groups[0]['lr']
                    })
            t1 = time.time()

    print('\r', end='')

@torch.no_grad()
def val_one_epoch(model, loader, epoch, args):
    """Evaluate ``model`` on ``loader`` and report the mean loss and perplexity.

    Each process averages its own batch losses in float32.  When a process
    group is initialised the averages are summed onto global rank 0 with
    ``dist.reduce`` and divided by ``args.world_size``; without one the
    function runs as a single process.  Reporting (stdout and, if enabled,
    wandb) happens on global rank 0, the process that receives the reduced
    value.

    An empty loader is reported and skipped instead of failing with a
    division by zero.

    Args:
        model: callable returning an object with a ``loss`` attribute.
        loader: iterable of dicts with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors.
        epoch: epoch index, used for logging only.
        args: namespace providing ``device``, ``wandb`` and ``world_size``.
    """
    distributed = dist.is_available() and dist.is_initialized()
    # dist.reduce(dst=0) only delivers the sum to global rank 0.
    is_main_process = not distributed or dist.get_rank() == 0

    model.eval()
    loss_sum = 0
    num_batches = 0
    start_time = time.time()
    for batch in loader:
        num_batches += 1
        input_ids = batch['input_ids'].to(args.device)
        attention_mask = batch['attention_mask'].to(args.device)
        labels = batch['labels'].to(args.device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        # Accumulate in float32 regardless of the dtype the model returns.
        loss_sum += outputs.loss.to(torch.float32)

    if num_batches == 0:
        # NOTE(review): assumes every process sees an empty loader at the
        # same time, so all of them skip the collective together -- confirm.
        if is_main_process:
            print('\nValidation skipped: the loader produced no batches', end='')
        return

    total_loss = loss_sum / num_batches
    if distributed:
        dist.reduce(total_loss, dst=0, op=dist.ReduceOp.SUM)
    if is_main_process:
        total_loss /= args.world_size
        ppl = torch.exp(total_loss)
        if args.wandb:
            wandb.log({
                "epoch": epoch,
                "val loss": total_loss.item(),
                "val PPL": ppl.item()
            })
        print(f'\nValidation Loss: {total_loss:.3f}   PPL: {ppl:.1f}   Time: {time.time() - start_time:.3f}s', end='')

# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
